{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rBvZfIhJO-Wc"
      },
      "source": [
        "Licensed under the Apache License, Version 2.0"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "D54H_RXGHwUG"
      },
      "source": [
        "# Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1tgIZB6DGTWX"
      },
      "outputs": [],
      "source": [
        "import importlib\n",
        "import json\n",
        "import math\n",
        "import os\n",
        "import shutil\n",
        "from copy import deepcopy\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import scipy\n",
        "import seaborn as sns\n",
        "from google.colab import files\n",
        "from sklearn.linear_model import LinearRegression\n",
        "\n",
        "from methods import m1, m2, m3, m4"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eSbe97BrHuk4"
      },
      "outputs": [],
      "source": [
        "# Short aliases for the Estimator class exported by each method module.\n",
        "M1, M2, M3, M4 = (module.Estimator for module in (m1, m2, m3, m4))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LtDWqlDerDBr"
      },
      "source": [
        "# Load data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wNYtPfv0iSIH"
      },
      "outputs": [],
      "source": [
        "# Delete any copy left over from an earlier run before uploading again\n",
        "# (presumably so the path read below is the fresh upload - confirm). `-f` keeps\n",
        "# `rm` quiet when the file does not exist yet, e.g. on a fresh runtime.\n",
        "!rm -f benchmark.vision.csv\n",
        "f = files.upload()\n",
        "df_vision = pd.read_csv('benchmark.vision.csv')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YI7NKxxrfrrY"
      },
      "outputs": [],
      "source": [
        "# Delete any copy left over from an earlier run before uploading again\n",
        "# (presumably so the path read below is the fresh upload - confirm). `-f` keeps\n",
        "# `rm` quiet when the file does not exist yet, e.g. on a fresh runtime.\n",
        "!rm -f benchmark.lang.csv\n",
        "f = files.upload()\n",
        "df_lang = pd.read_csv('benchmark.lang.csv')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sv5vqVaCitjb"
      },
      "outputs": [],
      "source": [
        "# Stack the vision and language benchmarks into a single table.\n",
        "# The last column ('Training') indicates whether a row should be used for\n",
        "# training (1) or testing (0).\n",
        "# ignore_index=True gives every row a unique label instead of repeating the\n",
        "# row labels of the two source tables.\n",
        "df_all = pd.concat([df_vision, df_lang], ignore_index=True)\n",
        "df_all"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1QJsiJV2ivhN"
      },
      "source": [
        "# Evaluate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N9Gr4P_MP-ts"
      },
      "outputs": [],
      "source": [
        "# Seed NumPy's global random number generator before any fitting is done.\n",
        "# NOTE(review): whether the estimators draw from this generator is not visible\n",
        "# in this notebook - confirm in the `methods` package.\n",
        "np.random.seed(2021)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fYnZYeyTrcEo"
      },
      "outputs": [],
      "source": [
        "def get_error(slaw, x, y):\n",
        "  \"\"\"Score a scaling law estimator on test points using log-space error.\n",
        "\n",
        "  Args:\n",
        "    slaw: estimator exposing `predict_loss(xi)` for a single data size.\n",
        "    x: 1d array containing data sizes.\n",
        "    y: 1d array containing errors/losses.\n",
        "\n",
        "  Returns:\n",
        "    Tuple `(rmse, spread)`. `rmse` is the square root of the mean squared\n",
        "    difference between log-predictions and log-targets. `spread` is how much\n",
        "    that root grows when the standard error of the squared differences is\n",
        "    added to their mean.\n",
        "  \"\"\"\n",
        "  predictions = np.array([slaw.predict_loss(size) for size in x])\n",
        "  sq_log_diff = (np.log(predictions) - np.log(y)) ** 2\n",
        "  mean_sq = np.mean(sq_log_diff)\n",
        "  rmse = np.sqrt(mean_sq)\n",
        "  # standard error of the mean of the squared differences\n",
        "  mean_sq_stderr = np.std(sq_log_diff) / (len(predictions) ** 0.5)\n",
        "  return rmse, np.sqrt(mean_sq + mean_sq_stderr) - rmse\n",
        "\n",
        "def create_dir(dir_name):\n",
        "  \"\"\"Create `dir_name` as an empty directory, discarding previous contents.\n",
        "\n",
        "  Args:\n",
        "    dir_name: path of the directory to (re)create.\n",
        "\n",
        "  Raises:\n",
        "    OSError: if an existing directory cannot be removed or the new one cannot\n",
        "      be created.\n",
        "  \"\"\"\n",
        "  # if dir exists, remove it and its contents\n",
        "  try:\n",
        "    shutil.rmtree(dir_name)\n",
        "  except FileNotFoundError:\n",
        "    pass  # nothing to clean up when the directory does not exist yet\n",
        "  os.makedirs(dir_name)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dOSXWdz7uGjW"
      },
      "outputs": [],
      "source": [
        "# Results accumulated by the sections below:\n",
        "# scaling_laws[key][mode] holds the fitted estimator for each method;\n",
        "# errors[key][mode] holds the matching get_error(...) result.\n",
        "scaling_laws, errors = dict(), dict()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_PkGk5U4i2Dg"
      },
      "source": [
        "## Image Classification"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5jimwEw_i6px"
      },
      "outputs": [],
      "source": [
        "# Restrict the benchmark table to the image-classification ('IC') rows.\n",
        "domain = 'IC'\n",
        "df = df_all[df_all['Domain'] == domain]\n",
        "# A set of strings can iterate in a different order after each kernel restart\n",
        "# (hash randomisation), which would reorder the fits below from run to run.\n",
        "# Sort to keep the notebook reproducible; key=str tolerates non-string entries\n",
        "# such as NaN.\n",
        "tasks = sorted(set(df['Task']), key=str)\n",
        "models = sorted(set(df['Model']), key=str)\n",
        "print('Tasks: ', tasks)\n",
        "print('Models: ', models)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JM4O3ibJjCrx"
      },
      "outputs": [],
      "source": [
        "# Downstream task names, grouped under the dataset they belong to.\n",
        "downstream_groups = {\n",
        "    'Birds': ['bird_5', 'bird_10', 'bird_25'],\n",
        "    'CIFAR100': ['c_5', 'c_10', 'c_25'],\n",
        "    'Caltech101': ['cal_5', 'cal_10', 'cal_25'],\n",
        "    'ImageNet': ['inet_5', 'inet_10', 'inet_25'],\n",
        "}\n",
        "\n",
        "# Fit every estimator on each (model, downstream task) pair and score it on the\n",
        "# test rows.\n",
        "for model in models:\n",
        "  for group in downstream_groups:\n",
        "    for downstream in downstream_groups[group]:\n",
        "      key = (domain, model, downstream)\n",
        "      print()\n",
        "      print(key)\n",
        "      pair_rows = (df['Model'] == model) \u0026 (df['Task'] == downstream)\n",
        "      # rows flagged Training == 1 are the training data\n",
        "      df_subset1 = df[pair_rows \u0026 (df['Training'] == 1)]\n",
        "      x_vals = np.array(df_subset1['Seen Examples'])\n",
        "      y_vals = np.array(df_subset1['Loss'])\n",
        "      fit_values = dict(zip(x_vals, y_vals))\n",
        "      # rows flagged Training == 0 are the test data\n",
        "      df_subset0 = df[pair_rows \u0026 (df['Training'] == 0)]\n",
        "      x_test = np.array(df_subset0['Seen Examples'])\n",
        "      y_test = np.array(df_subset0['Loss'])\n",
        "\n",
        "      scaling_laws[key] = {}\n",
        "      errors[key] = {}\n",
        "      for mode in ['M1', 'M2', 'M3', 'M4']:\n",
        "        print(mode)\n",
        "        # build the estimator for this method\n",
        "        if mode == 'M4':\n",
        "          law = M4(fit_values, err_inf=None, err_0=1.0,\n",
        "                   update_err_0=True, up_bound=1.0)\n",
        "        elif mode == 'M3':\n",
        "          law = M3(fit_values)\n",
        "        elif mode == 'M2':\n",
        "          law = M2(fit_values)\n",
        "        else:\n",
        "          law = M1(fit_values)\n",
        "        scaling_laws[key][mode] = law\n",
        "        # fit\n",
        "        law.estimate_scaling_params(verbose=0, max_iterations=10_000)\n",
        "        # report the fitted parameters\n",
        "        if mode in ('M1', 'M2'):\n",
        "          fitted = 'beta, c, err_inf =\\t\\t %.2f, %0.2f, %0.2f' % (\n",
        "              law.beta, law.c, law.err_inf)\n",
        "        elif mode == 'M3':\n",
        "          fitted = 'beta, c, gamma =\\t\\t %.2f, %0.2f, %0.2f' % (\n",
        "              law.beta, law.c, law.gamma)\n",
        "        else:\n",
        "          fitted = 'beta, c, alpha, err_inf =\\t %.2f, %0.2f, %0.2f, %0.2f' % (\n",
        "              law.beta, law.c, law.alpha, law.err_inf)\n",
        "        print(fitted)\n",
        "\n",
        "        # record and report the extrapolation error\n",
        "        errors[key][mode] = get_error(law, x_test, y_test)\n",
        "        print('Extrapolation Loss =\\t\\t %.4f +- %.5f' % errors[key][mode])\n",
        "        print()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eY_soh1N1ko3"
      },
      "outputs": [],
      "source": [
        "# (Re)create the output folder that the figure cell below saves into.\n",
        "create_dir('image_classification')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1HPFiuwCzps3"
      },
      "outputs": [],
      "source": [
        "sns.set_theme(context='paper', style='whitegrid', palette='colorblind',\n",
        "              font_scale=1.75, rc={'lines.linewidth': 2})\n",
        "# Only image-classification entries have (domain, model, downstream) keys.\n",
        "# Selecting them (and their rows) explicitly keeps this cell re-runnable after\n",
        "# later sections have added their own entries to scaling_laws and rebound df.\n",
        "df_ic = df_all[df_all['Domain'] == 'IC']\n",
        "for key in [k for k in scaling_laws if k[0] == 'IC']:\n",
        "  model, downstream = key[1], key[2]\n",
        "  # The data depends only on (model, downstream): fetch it once per figure.\n",
        "  pair_rows = (df_ic['Model'] == model) \u0026 (df_ic['Task'] == downstream)\n",
        "  # fetch training data\n",
        "  df_subset1 = df_ic[pair_rows \u0026 (df_ic['Training'] == 1)]\n",
        "  x_vals = np.array(df_subset1['Seen Examples'])\n",
        "  y_vals = np.array(df_subset1['Loss'])\n",
        "  fit_values = {x: y for x, y in zip(x_vals, y_vals)}\n",
        "  # fetch test data\n",
        "  df_subset0 = df_ic[pair_rows \u0026 (df_ic['Training'] == 0)]\n",
        "  x_test = np.array(df_subset0['Seen Examples'])\n",
        "  y_test = np.array(df_subset0['Loss'])\n",
        "  x_lo = min(df_subset1['Seen Examples'])  # min of fitting data\n",
        "  x_hi = max(df_subset0['Seen Examples']) * 3  # 3x the max of extrapolation data\n",
        "\n",
        "  fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(20, 4))\n",
        "  miny = 1  # used to rescale the y-axis in the figures.\n",
        "  for fig_id, mode in enumerate(['M1', 'M2', 'M3', 'M4']):\n",
        "    law = scaling_laws[key][mode]\n",
        "    ax = axes[fig_id]\n",
        "\n",
        "    ax.scatter(list(fit_values.keys()), list(fit_values.values()),\n",
        "               c='black', alpha=0.5)\n",
        "    ax.scatter(x_test, y_test,  c='tab:orange')\n",
        "    xt, yt = law.loss_curve(x_lo, x_hi)\n",
        "\n",
        "    ax.plot(xt, yt, color='tab:blue')\n",
        "    ax.set_xscale('log')\n",
        "    ax.set_title(mode.upper())\n",
        "    ax.xaxis.set_ticklabels([])\n",
        "    if mode == 'M1':\n",
        "      ax.set_ylabel('Error Rate')\n",
        "    ax.set_xlabel('Examples Seen')\n",
        "\n",
        "    # Ignore NaN/inf curve values when choosing the lower y-limit; int() of a\n",
        "    # non-finite value would raise. Assumes yt is array-like of numbers.\n",
        "    finite_yt = np.asarray(yt, dtype=float)\n",
        "    finite_yt = finite_yt[np.isfinite(finite_yt)]\n",
        "    if finite_yt.size:\n",
        "      miny = min(int(finite_yt.min() * 10) / 10, miny)\n",
        "\n",
        "  plt.setp(axes, ylim=(miny, 1))\n",
        "  plt.setp(axes, xlim=(x_lo, x_hi))\n",
        "  subtitle = key[1] + ', ' + key[2]\n",
        "  fig.suptitle(subtitle)\n",
        "  fig.tight_layout()\n",
        "  filename = 'image_classification/' + model.replace('/', '_') + '_' + downstream + '.pdf'\n",
        "  # Save before showing so the file does not depend on what the backend does\n",
        "  # to the figure in show(); then release the figure's memory.\n",
        "  fig.savefig(filename, dpi=200)\n",
        "  plt.show()\n",
        "  plt.close(fig)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KDcjCEBP1zCo"
      },
      "outputs": [],
      "source": [
        "# download figures: bundle the output folder into image_classification.zip and\n",
        "# send it to the browser via the google.colab `files` helper imported above.\n",
        "shutil.make_archive('image_classification', 'zip', 'image_classification')\n",
        "files.download('image_classification.zip')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mHujgpCv2_qj"
      },
      "outputs": [],
      "source": [
        "# win probabilities: for each method, the share of entries in `errors` on\n",
        "# which it attains the lowest extrapolation error.\n",
        "modes = ['M1', 'M2', 'M3', 'M4']\n",
        "wins = np.zeros(len(modes))\n",
        "for key in errors:\n",
        "  e = np.array([errors[key][mode][0] for mode in modes], dtype=float)\n",
        "  # If a method fails (NaN or infinite error), set its loss to max. Infinite\n",
        "  # values are included because truncating them with int() would raise.\n",
        "  e[~np.isfinite(e)] = 1\n",
        "  e = np.trunc(e * 1_000) / 1_000  # compare performance up to 3 decimal places.\n",
        "  is_best = e == np.min(e)\n",
        "  wins = wins + is_best / is_best.sum()  # if multiple winners, divide score\n",
        "# max(..., 1) avoids dividing by zero when nothing has been fitted yet\n",
        "print(wins / max(len(errors), 1))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NPVFmQN3OHjm"
      },
      "source": [
        "## NMT"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AqLspHUmEWLq"
      },
      "outputs": [],
      "source": [
        "# Restrict the benchmark table to the 'NMT' rows.\n",
        "domain = 'NMT'\n",
        "df = df_all[df_all['Domain'] == domain]\n",
        "# A set of strings can iterate in a different order after each kernel restart\n",
        "# (hash randomisation), which would reorder the fits below from run to run.\n",
        "# Sort to keep the notebook reproducible; key=str tolerates non-string entries\n",
        "# such as NaN.\n",
        "tasks = sorted(set(df['Task']), key=str)\n",
        "models = sorted(set(df['Model']), key=str)\n",
        "print('Tasks: ', tasks)\n",
        "print('Models: ', models)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m7C9kU2kEgMK"
      },
      "outputs": [],
      "source": [
        "# Display the rows currently bound to `df` for inspection.\n",
        "df"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Vd3-6CeaEgmt"
      },
      "outputs": [],
      "source": [
        "# Fit every estimator on each model's rows and score it on the test rows.\n",
        "for model in models:\n",
        "  key = (domain, model)\n",
        "  print()\n",
        "  print(key)\n",
        "  model_rows = df['Model'] == model\n",
        "  # rows flagged Training == 1 are the training data\n",
        "  df_subset1 = df[model_rows \u0026 (df['Training'] == 1)]\n",
        "  x_vals = np.array(df_subset1['Seen Examples'])\n",
        "  y_vals = np.array(df_subset1['Loss'])\n",
        "\n",
        "  fit_values = dict(zip(x_vals, y_vals))\n",
        "  # rows flagged Training == 0 are the test data\n",
        "  df_subset0 = df[model_rows \u0026 (df['Training'] == 0)]\n",
        "  x_test = np.array(df_subset0['Seen Examples'])\n",
        "  y_test = np.array(df_subset0['Loss'])\n",
        "\n",
        "  scaling_laws[key] = {}\n",
        "  errors[key] = {}\n",
        "  for mode in ['M1', 'M2', 'M3', 'M4']:\n",
        "    print(mode)\n",
        "    # build the estimator for this method\n",
        "    if mode == 'M4':\n",
        "      # no upper bound since this is log-perplexity\n",
        "      law = M4(fit_values, err_0=1.0, update_err_0=True, up_bound=None)\n",
        "    elif mode == 'M3':\n",
        "      law = M3(fit_values)\n",
        "    elif mode == 'M2':\n",
        "      law = M2(fit_values)\n",
        "    else:\n",
        "      law = M1(fit_values)\n",
        "    scaling_laws[key][mode] = law\n",
        "    # fit\n",
        "    law.estimate_scaling_params(verbose=0)\n",
        "    # report the fitted parameters\n",
        "    if mode in ('M1', 'M2'):\n",
        "      fitted = 'beta, c, err_inf =\\t\\t %.2f, %0.2f, %0.2f' % (\n",
        "          law.beta, law.c, law.err_inf)\n",
        "    elif mode == 'M3':\n",
        "      fitted = 'beta, c, gamma =\\t\\t %.2f, %0.2f, %0.2f' % (\n",
        "          law.beta, law.c, law.gamma)\n",
        "    else:\n",
        "      fitted = 'beta, c, alpha, err_inf =\\t %.2f, %0.2f, %0.2f, %0.2f' % (\n",
        "          law.beta, law.c, law.alpha, law.err_inf)\n",
        "    print(fitted)\n",
        "\n",
        "    # record and report the extrapolation error\n",
        "    errors[key][mode] = get_error(law, x_test, y_test)\n",
        "    print('Extrapolation Loss =\\t\\t %.4f +- %.5f' % errors[key][mode])\n",
        "    print()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ciMRzpX1LbDb"
      },
      "source": [
        "## Language Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gcpwMuX4Ljdr"
      },
      "outputs": [],
      "source": [
        "# Restrict the benchmark table to the language-model ('LM') rows.\n",
        "domain = 'LM'\n",
        "df = df_all[df_all['Domain'] == domain]\n",
        "# A set of strings can iterate in a different order after each kernel restart\n",
        "# (hash randomisation), which would reorder the fits and the figure rows below\n",
        "# from run to run. Sort to keep the notebook reproducible; key=str tolerates\n",
        "# non-string entries such as NaN.\n",
        "tasks = sorted(set(df['Task']), key=str)\n",
        "models = sorted(set(df['Model']), key=str)\n",
        "print('Tasks: ', tasks)\n",
        "print('Models: ', models)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1WO2cr5uLm8g"
      },
      "outputs": [],
      "source": [
        "# Display the rows currently bound to `df` for inspection.\n",
        "df"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "V4XhGTysLnUg"
      },
      "outputs": [],
      "source": [
        "# Fit every estimator on each model's rows and score it on the test rows.\n",
        "for model in models:\n",
        "  key = (domain, model)\n",
        "  print()\n",
        "  print(key)\n",
        "  model_rows = df['Model'] == model\n",
        "  # rows flagged Training == 1 are the training data\n",
        "  df_subset1 = df[model_rows \u0026 (df['Training'] == 1)]\n",
        "  x_vals = np.array(df_subset1['Seen Examples'])\n",
        "  y_vals = np.array(df_subset1['Loss'])  # rescaled already\n",
        "\n",
        "  fit_values = dict(zip(x_vals, y_vals))\n",
        "  # rows flagged Training == 0 are the test data\n",
        "  df_subset0 = df[model_rows \u0026 (df['Training'] == 0)]\n",
        "  x_test = np.array(df_subset0['Seen Examples'])\n",
        "  y_test = np.array(df_subset0['Loss'])\n",
        "\n",
        "  scaling_laws[key] = {}\n",
        "  errors[key] = {}\n",
        "  for mode in ['M1', 'M2', 'M3', 'M4']:\n",
        "    print(mode)\n",
        "    # build the estimator for this method\n",
        "    if mode == 'M4':\n",
        "      # no upper bound since this is cross-entropy loss\n",
        "      law = M4(fit_values, err_0=1.0, update_err_0=True, up_bound=None)\n",
        "    elif mode == 'M3':\n",
        "      law = M3(fit_values)\n",
        "    elif mode == 'M2':\n",
        "      law = M2(fit_values)\n",
        "    else:\n",
        "      law = M1(fit_values)\n",
        "    scaling_laws[key][mode] = law\n",
        "    # fit\n",
        "    law.estimate_scaling_params(verbose=0)\n",
        "    # report the fitted parameters\n",
        "    if mode in ('M1', 'M2'):\n",
        "      fitted = 'beta, c, err_inf =\\t\\t %.2f, %0.2f, %0.2f' % (\n",
        "          law.beta, law.c, law.err_inf)\n",
        "    elif mode == 'M3':\n",
        "      fitted = 'beta, c, gamma =\\t\\t %.2f, %0.2f, %0.2f' % (\n",
        "          law.beta, law.c, law.gamma)\n",
        "    else:\n",
        "      fitted = 'beta, c, alpha, err_inf =\\t %.2f, %0.2f, %0.2f, %0.2f' % (\n",
        "          law.beta, law.c, law.alpha, law.err_inf)\n",
        "    print(fitted)\n",
        "\n",
        "    # record and report the extrapolation error\n",
        "    errors[key][mode] = get_error(law, x_test, y_test)\n",
        "    print('Extrapolation Loss =\\t\\t %.4f +- %.5f' % errors[key][mode])\n",
        "    print()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-yTEyZoGfcFX"
      },
      "outputs": [],
      "source": [
        "sns.set_theme(context='paper', style='whitegrid', palette='colorblind',\n",
        "              font_scale=1.25, rc={'lines.linewidth': 2})\n",
        "\n",
        "# One row of panels per model, one column per method. The grid is sized from\n",
        "# `models` (20 x 10 inches when there are five models) instead of assuming a\n",
        "# fixed number of rows; squeeze=False keeps `axes` 2-D even for a single model.\n",
        "modes = ['M1', 'M2', 'M3', 'M4']\n",
        "n_rows = len(models)\n",
        "fig, axes = plt.subplots(nrows=n_rows, ncols=len(modes),\n",
        "                         figsize=(20, 2 * n_rows), squeeze=False)\n",
        "for row_id, model in enumerate(models):\n",
        "  # The data depends only on the model: fetch it once per row.\n",
        "  # fetch training data\n",
        "  df_subset1 = df[(df['Model'] == model) \u0026 (df['Training'] == 1)]\n",
        "  x_vals = np.array(df_subset1['Seen Examples'])\n",
        "  y_vals = np.array(df_subset1['Loss'])\n",
        "  # fetch test data\n",
        "  df_subset0 = df[(df['Model'] == model) \u0026 (df['Training'] == 0)]\n",
        "  x_test = np.array(df_subset0['Seen Examples'])\n",
        "  y_test = np.array(df_subset0['Loss'])\n",
        "\n",
        "  key = (domain, model)\n",
        "  for fig_id, mode in enumerate(modes):\n",
        "    law = scaling_laws[key][mode]\n",
        "    ax = axes[row_id, fig_id]\n",
        "    ax.scatter(x_vals, y_vals, c='black', alpha=0.5)\n",
        "    ax.scatter(x_test, y_test,  c='tab:orange')\n",
        "    xt, yt = law.loss_curve(min(df_subset1['Seen Examples']),  # min of fitting data\n",
        "                            max(df_subset0['Seen Examples']) * 3)  # 3x the max of extrapolation data\n",
        "\n",
        "    ax.plot(xt, yt, color='tab:blue')\n",
        "    ax.set_xscale('log')\n",
        "    ax.set_title(mode.upper())\n",
        "\n",
        "    if fig_id == 0:\n",
        "      ax.set_ylabel('P'+model)\n",
        "    if row_id == n_rows - 1:\n",
        "      ax.set_xlabel('Tokens')\n",
        "\n",
        "# Shared x-range computed over every model's rows, rather than from whichever\n",
        "# model happened to be iterated last.\n",
        "x_lo = df.loc[df['Training'] == 1, 'Seen Examples'].min()\n",
        "x_hi = df.loc[df['Training'] == 0, 'Seen Examples'].max() * 3\n",
        "plt.setp(axes, xlim=(x_lo, x_hi))\n",
        "fig.tight_layout()\n",
        "filename = 'lm.pdf'\n",
        "# Save before showing so the file does not depend on what the backend does to\n",
        "# the figure in show().\n",
        "fig.savefig(filename, dpi=200)\n",
        "plt.show()\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A6OD-eLgPVIN"
      },
      "source": [
        "## Big Bench"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6qJRi-6ffYT0"
      },
      "outputs": [],
      "source": [
        "# Restrict the benchmark table to the Big Bench ('BB') rows.\n",
        "domain = 'BB'\n",
        "df = df_all[df_all['Domain'] == domain]\n",
        "# A set of strings can iterate in a different order after each kernel restart\n",
        "# (hash randomisation), which would reorder the fits below from run to run.\n",
        "# Sort to keep the notebook reproducible; key=str tolerates non-string entries\n",
        "# such as NaN.\n",
        "tasks = sorted(set(df['Task']), key=str)\n",
        "models = sorted(set(df['Model']), key=str)\n",
        "print('Tasks: ', tasks)\n",
        "print('Models: ', models)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DKfKtVHgBJ4X"
      },
      "outputs": [],
      "source": [
        "# Display the rows currently bound to `df` for inspection.\n",
        "df"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TQG1XYNG9AcW"
      },
      "outputs": [],
      "source": [
        "# Display the task names that the loop below iterates over.\n",
        "tasks"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-N4Lru7ABLRT"
      },
      "outputs": [],
      "source": [
        "# Fit every estimator on each task's rows and score it on the test rows.\n",
        "for task in tasks:\n",
        "  key = (domain, task)\n",
        "  print()\n",
        "  print(key)\n",
        "  task_rows = df['Task'] == task\n",
        "  # rows flagged Training == 1 are the training data\n",
        "  df_subset1 = df[task_rows \u0026 (df['Training'] == 1)]\n",
        "  x_vals = np.array(df_subset1['Seen Examples'])\n",
        "  y_vals = np.array(df_subset1['Loss'])  # rescaled already\n",
        "\n",
        "  fit_values = dict(zip(x_vals, y_vals))\n",
        "  # rows flagged Training == 0 are the test data\n",
        "  df_subset0 = df[task_rows \u0026 (df['Training'] == 0)]\n",
        "  x_test = np.array(df_subset0['Seen Examples'])\n",
        "  y_test = np.array(df_subset0['Loss'])\n",
        "\n",
        "  scaling_laws[key] = {}\n",
        "  errors[key] = {}\n",
        "  for mode in ['M1', 'M2', 'M3', 'M4']:\n",
        "    print(mode)\n",
        "    # build the estimator for this method\n",
        "    if mode == 'M4':\n",
        "      law = M4(fit_values, err_0=1.001, update_err_0=True, up_bound=1.0)\n",
        "    elif mode == 'M3':\n",
        "      law = M3(fit_values)\n",
        "    elif mode == 'M2':\n",
        "      law = M2(fit_values)\n",
        "    else:\n",
        "      law = M1(fit_values)\n",
        "    scaling_laws[key][mode] = law\n",
        "    # fit\n",
        "    law.estimate_scaling_params(verbose=0)\n",
        "    # report the fitted parameters\n",
        "    if mode in ('M1', 'M2'):\n",
        "      fitted = 'beta, c, err_inf =\\t\\t %.2f, %0.2f, %0.2f' % (\n",
        "          law.beta, law.c, law.err_inf)\n",
        "    elif mode == 'M3':\n",
        "      fitted = 'beta, c, gamma =\\t\\t %.2f, %0.2f, %0.2f' % (\n",
        "          law.beta, law.c, law.gamma)\n",
        "    else:\n",
        "      fitted = 'beta, c, alpha, err_inf =\\t %.2f, %0.2f, %0.2f, %0.2f' % (\n",
        "          law.beta, law.c, law.alpha, law.err_inf)\n",
        "    print(fitted)\n",
        "\n",
        "    # record and report the extrapolation error\n",
        "    errors[key][mode] = get_error(law, x_test, y_test)\n",
        "    print('Extrapolation Loss =\\t\\t %.4f +- %.5f' % errors[key][mode])\n",
        "    print()\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "revisiting_neural_scaling_laws.ipynb",
      "private_outputs": true,
      "provenance": [
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
